{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import jieba\n",
    "import warnings\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn import metrics\n",
    "\n",
    "# Install a catch-all 'ignore' filter: warnings raised after this point are not shown.\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /var/folders/q4/h78r_bn16mj0ftkwfrvt0x2w0000gn/T/jieba.cache\n",
      "Loading model cost 0.685 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    }
   ],
   "source": [
    "def cut_words(file_path):\n",
    "    \"\"\"\n",
    "    Segment the text of one file into words with jieba.\n",
    "\n",
    "    :param file_path: path of a GB18030-encoded txt file\n",
    "    :return: the segmented text as one string, every token followed by a space\n",
    "    \"\"\"\n",
    "    # The context manager closes the handle even when reading or decoding fails.\n",
    "    with open(file_path, 'r', encoding='gb18030') as text_file:\n",
    "        text = text_file.read()\n",
    "    # Build the result in one join instead of growing a string inside a loop;\n",
    "    # each token keeps its trailing space so the output format is unchanged.\n",
    "    return ''.join(word + ' ' for word in jieba.cut(text))\n",
    "\n",
    "def loadfile(file_dir, label):\n",
    "    \"\"\"\n",
    "    Segment every file found directly under a directory.\n",
    "\n",
    "    :param file_dir: directory that holds the txt files\n",
    "    :param label: class label assigned to every document in the directory\n",
    "    :return: (words_list, labels_list) - the segmented documents and one label per document\n",
    "    \"\"\"\n",
    "    words_list = []\n",
    "    labels_list = []\n",
    "    # os.listdir returns entries in arbitrary order; sort them so runs are reproducible.\n",
    "    for file in sorted(os.listdir(file_dir)):\n",
    "        file_path = os.path.join(file_dir, file)\n",
    "        # A sub-directory cannot be opened as a text file, so skip anything that is not a file.\n",
    "        if not os.path.isfile(file_path):\n",
    "            continue\n",
    "        words_list.append(cut_words(file_path))\n",
    "        labels_list.append(label)\n",
    "    return words_list, labels_list\n",
    "\n",
    "# Each category folder name doubles as the class label; order fixes the document order.\n",
    "categories = ['女性', '体育', '文学', '校园']\n",
    "\n",
    "# Training data: concatenate the documents and labels of every category.\n",
    "train_words_list, train_labels = [], []\n",
    "for category in categories:\n",
    "    words, labels = loadfile('text classification/train/' + category, category)\n",
    "    train_words_list += words\n",
    "    train_labels += labels\n",
    "\n",
    "# Test data: same layout under the test directory.\n",
    "test_words_list, test_labels = [], []\n",
    "for category in categories:\n",
    "    words, labels = loadfile('text classification/test/' + category, category)\n",
    "    test_words_list += words\n",
    "    test_labels += labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the stop-word list, one word per line. The 'utf-8-sig' codec drops a leading\n",
    "# BOM (\\ufeff) if present, and the context manager closes the file.\n",
    "with open('text classification/stop/stopword.txt', 'r', encoding='utf-8-sig') as stop_file:\n",
    "    # Trim surrounding whitespace (including '\\r') and drop blank lines so that\n",
    "    # every entry can actually match a token.\n",
    "    stop_words = [line.strip() for line in stop_file if line.strip()]\n",
    "\n",
    "# Compute TF-IDF term weights; max_df=0.5 is passed through to TfidfVectorizer unchanged.\n",
    "tf = TfidfVectorizer(stop_words=stop_words, max_df=0.5)\n",
    "\n",
    "train_features = tf.fit_transform(train_words_list)\n",
    "test_features = tf.transform(test_words_list) # the vectorizer is fitted above, so only transform here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multinomial naive Bayes classifier (MultinomialNB is already imported in the first cell).\n",
    "clf = MultinomialNB(alpha=0.001)\n",
    "clf.fit(train_features, train_labels)\n",
    "\n",
    "# Predict a label for every test document.\n",
    "predicted_labels = clf.predict(test_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "准确率为： 0.91\n"
     ]
    }
   ],
   "source": [
    "# Compute the accuracy of the predictions against the test labels.\n",
    "accuracy = metrics.accuracy_score(test_labels, predicted_labels)\n",
    "print('准确率为：', accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
